{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "We will build a real-vs-fake news detection engine. The goal is to show how this pipeline can be adapted to your organization's specific needs: instead of using a pre-built dataset, we download a dataset from Kaggle and use it for fine-tuning, which illustrates how the pipeline can be tailored to custom datasets in real-world applications.\n",
        "\n",
        "Here's an outline of the fine-tuning process:\n",
        "\n",
        "1. Import required libraries and packages\n",
        "2. Load the dataset: download the data from Kaggle and save it on your drive\n",
        "3. Load the pre-trained BERT tokenizer\n",
        "4. Prepare the dataset:\n",
        "   * Split the dataset into training and validation sets\n",
        "   * Tokenize the text using the BERT tokenizer\n",
        "   * Create attention masks\n",
        "   * Create a custom PyTorch dataset class (TextClassificationDataset)\n",
        "   * Instantiate the custom dataset for both training and validation sets\n",
        "   * Create PyTorch DataLoaders\n",
        "5. Load a pre-trained BERT model for sequence classification using the Hugging Face Transformers library, and set up the Accelerator environment\n",
        "6. Fine-tune the model\n",
        "7. Evaluate the model:\n",
        "   * Calculate metrics such as F1 score, recall, and precision\n",
        "8. Inference:\n",
        "   * Create a function to perform inference on new text input\n",
        "   * Tokenize the input text and convert it to the required format\n",
        "   * Perform inference using the fine-tuned model\n",
        "   * Interpret the model's output and return the predicted class"
      ],
      "metadata": {
        "id": "y-o9q5TSzR8a"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 1. Import required libraries and packages\n",
        " "
      ],
      "metadata": {
        "id": "WrC8FHao2aon"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "from accelerate import Accelerator\n",
        "\n",
        "import torch\n",
        "from torch.optim import AdamW  # transformers.AdamW is deprecated in favor of the PyTorch implementation\n",
        "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n",
        "\n",
        "from tqdm import tqdm\n",
        "\n",
        "from transformers import AutoTokenizer\n",
        "from transformers import AutoModelForSequenceClassification\n",
        "from transformers import get_scheduler"
      ],
      "metadata": {
        "id": "rOit_UrJdw9-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Note:** MPS=> Apple's Metal Performance Shaders (MPS) is a framework that provides highly optimized, low-level GPU-accelerated functions for deep learning, image processing, and other compute-intensive tasks."
      ],
      "metadata": {
        "id": "h2Z31EI_y08Q"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def get_device():\n",
        "  # Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.\n",
        "  if torch.cuda.is_available():\n",
        "    return \"cuda\"\n",
        "  if torch.backends.mps.is_available():\n",
        "    return \"mps\"\n",
        "  return \"cpu\"\n",
        "\n",
        "\n",
        "device = get_device()\n",
        "print(device) "
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "XdAE-Xc4cg9D",
        "outputId": "d7ef1852-9464-4076-854c-2d5b45d2a94b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "mps\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 2. Load Data\n",
        "1. Reading data from two CSV files: True.csv (real news) and Fake.csv (fake news)\n",
        "2. Cleaning and preprocessing the data in each CSV file\n",
        "3. Concatenating both dataframes into a single dataframe\n",
        "4. The resulting dataframe contains two columns: 'text' for the news content and 'label' for its corresponding category (real or fake)"
      ],
      "metadata": {
        "id": "aOT9-e3R15hG"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "real=pd.read_csv('/Users/premtimsina/Documents/bpbbook/chapter4/dataset/True.csv')\n",
        "fake=pd.read_csv('/Users/premtimsina/Documents/bpbbook/chapter4/dataset/Fake.csv')\n"
      ],
      "metadata": {
        "id": "f2YociOdJ4CD"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "real = real.drop(['title','subject','date'], axis=1)\n",
        "real['label']=1.0\n",
        "fake = fake.drop(['title','subject','date'], axis=1)\n",
        "fake['label']=0.0\n",
        "dataframe=pd.concat([real, fake], axis=0, ignore_index=True)\n"
      ],
      "metadata": {
        "id": "IERhOScoLJM9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "df = dataframe.sample(frac=0.1).reset_index(drop=True)\n",
        "print(df.head(20))\n",
        "print(len(df[df['label']==1.0]))\n",
        "print(len(df[df['label']==0.0]))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yBU1b0_syl00",
        "outputId": "df0fbc86-91e6-4d46-af49-312160b52436"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "                                                 text  label\n",
            "0   Donald Trump turned a tragedy into an opportun...    0.0\n",
            "1   MARAWI CITY, Philippines (Reuters) - With a gr...    1.0\n",
            "2   This could be YUGE news for conservatives who ...    0.0\n",
            "3   WASHINGTON (Reuters) - The U.S. government is ...    1.0\n",
            "4   MOGADISHU (Reuters) - More than 300 people wer...    1.0\n",
            "5   Two white guys living in a state where 96% of ...    0.0\n",
            "6   WASHINGTON (Reuters) - U.S. Defense Secretary ...    1.0\n",
            "7                                                        0.0\n",
            "8   EXETER, N.H./MILFORD, N.H. (Reuters) - U.S. Se...    1.0\n",
            "9   (Reuters) - New Jersey Governor Chris Christie...    1.0\n",
            "10  The funniest part of the video comes when MSNB...    0.0\n",
            "11  Newsweek reporter Kurt Eichenwald came under f...    0.0\n",
            "12  https://fedup.wpengine.com/wp-content/uploads/...    0.0\n",
            "13  The Donald Trump campaign is relying heavily o...    0.0\n",
            "14  BAGHDAD (Reuters) - Arms provided by the Unite...    1.0\n",
            "15  WASHINGTON (Reuters) - A U.S. House of Represe...    1.0\n",
            "16  21st Century Wire says Our weekly documentary ...    0.0\n",
            "17  WASHINGTON (Reuters) - As the Senate began a t...    1.0\n",
            "18  Bill Clinton campaigned at a church in Flint, ...    0.0\n",
            "19  What Catholic college campus would be complete...    0.0\n",
            "2118\n",
            "2372\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "#3.  Load Tokenizer:\n",
        "1. We are using the `bert-base-uncased` tokenizer. We also need to use the corresponding model"
      ],
      "metadata": {
        "id": "UEE-QEUW3YyM"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")"
      ],
      "metadata": {
        "id": "AFP8kHXrzfSL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 4. Prepare Data\n",
        "The data preparation process for BERT-based uncased models involves tokenizing the text, mapping tokens to `input_ids`, creating attention masks `attention_mask`, , and preparing the labels tensor `labels`. Each element of Dataset Class should be dictionary of following structure.\n",
        "\n",
        "```\n",
        "{'input_ids': torch.Tensor(),'attention_mask':torch.Tensor(), 'labels': torch.Tensor()  }\n",
        "```\n",
        "1. Tokenization: The text input should be tokenized into subwords using BERT's WordPiece tokenizer. This tokenizer converts the text into a format that BERT can understand.\n",
        "\n",
        "2. `input_ids`: Each token from the tokenized text needs to be mapped to an ID using BERT's vocabulary. The resulting input IDs should be in the form of a tensor or array, usually of shape (batch_size, max_sequence_length).\n",
        "3. `attention_mask`: The attention mask is used to differentiate between the actual tokens and padding tokens. It has the same shape as the input IDs tensor, i.e., (batch_size, max_sequence_length). The mask has 1s for actual tokens and 0s for padding tokens.\n",
        "4. `labels`: The labels tensor contains the true class or value for each example in the dataset. It usually has a shape of (batch_size,). For classification tasks, these labels are one-hot-encoded labels"
      ],
      "metadata": {
        "id": "4SQ3sVVn3tSi"
      }
    },
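    {
      "cell_type": "markdown",
      "source": [
        "The padding and attention-mask convention described above can be sketched in plain Python. This is only a minimal illustration, not the tokenizer's actual implementation: `pad_and_mask` is a hypothetical helper, the token IDs are made up, and `pad_id=0` assumes that `0` is the padding ID.\n",
        "\n",
        "```python\n",
        "def pad_and_mask(token_ids, max_length, pad_id=0):\n",
        "    # Hypothetical helper: truncate or right-pad one sequence to max_length.\n",
        "    # (A real tokenizer also keeps the final [SEP] token when truncating.)\n",
        "    ids = list(token_ids)[:max_length]\n",
        "    n_real = len(ids)\n",
        "    n_pad = max_length - n_real\n",
        "    input_ids = ids + [pad_id] * n_pad\n",
        "    # 1 marks a real token, 0 marks padding.\n",
        "    attention_mask = [1] * n_real + [0] * n_pad\n",
        "    return input_ids, attention_mask\n",
        "\n",
        "\n",
        "# Made-up IDs standing in for a short tokenized sentence.\n",
        "input_ids, attention_mask = pad_and_mask([101, 7592, 2088, 102], max_length=8)\n",
        "print(input_ids)\n",
        "print(attention_mask)\n",
        "```\n",
        "\n",
        "Padding every example to the same length is what allows the per-example lists to be stacked into a single tensor of shape (batch_size, max_sequence_length)."
      ],
      "metadata": {}
    },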
    {
      "cell_type": "code",
      "source": [
        "# Build a list of tuples; each tuple is (text, label)\n",
        "data=list(zip(df['text'].tolist(), df['label'].tolist()))\n",
        "\n",
        "# Takes a list of texts and a list of labels as parameters;\n",
        "# returns tensors of input_ids, attention_masks, and labels\n",
        "def tokenize_and_encode(texts, labels):\n",
        "    input_ids, attention_masks, labels_out = [], [], []\n",
        "    for text, label in zip(texts, labels):\n",
        "        # Pad or truncate every example to 512 tokens so the results stack into one tensor\n",
        "        encoded = tokenizer(text, max_length=512, padding='max_length', truncation=True)\n",
        "        input_ids.append(encoded['input_ids'])\n",
        "        attention_masks.append(encoded['attention_mask'])\n",
        "        labels_out.append(label)\n",
        "    return torch.tensor(input_ids), torch.tensor(attention_masks), torch.tensor(labels_out)\n",
        "\n",
        "# Separate the tuples into two sequences:\n",
        "# a) texts, b) labels\n",
        "texts, labels = zip(*data)\n",
        "\n",
        "# train, validation split\n",
        "train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.2)\n",
        "\n",
        "# tokenization\n",
        "train_input_ids, train_attention_masks, train_labels = tokenize_and_encode(train_texts, train_labels)\n",
        "val_input_ids, val_attention_masks, val_labels = tokenize_and_encode(val_texts, val_labels)\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "id": "jbJiQDfw9tUI"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "**It's always good to review the data**\n",
        "1. input_ids\n",
        "  * `0` token value means padded token\n",
        "2. attention_mask\n",
        "  * `1`: corresponding token is real token\n",
        "  * `0`: corresponding token is padded token"
      ],
      "metadata": {
        "id": "A0sp2xdP7UPG"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "print('train_input_ids ',train_input_ids[0].shape ,train_input_ids[0], '\\n'\n",
        "      'train_attention_masks ', train_attention_masks[0] ,train_attention_masks[0], '\\n'\n",
        "      'train_labels', train_labels[0])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "K9VcMmQEeQ9H",
        "outputId": "d943ed4e-5694-4345-9ff3-cbde0b561bef"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "train_input_ids  torch.Size([512]) tensor([  101,   153,  9919, 11612, 11840,  2069,   117,  3658,   113, 11336,\n",
            "        27603,   114,   118,  1109, 19850,  9651, 17677,  1372,  1163,  1122,\n",
            "         2446,  1149,  1126,  2035,  1113,  3625,  1115,  1841,  1421,  1234,\n",
            "         1107,  1103, 10059,  1877, 20268,  6469,   119, 12008,  8167, 20059,\n",
            "          118,   174,   118, 17677,  3658,   113,   157, 17433,   114, 15465,\n",
            "        12460,   148, 26033, 11192,  7192,  1163,  1103, 19342,  8100,  1126,\n",
            "        22421, 12154,  4442,  1106,  4010,  2699,  4675,  1107,  1103, 21519,\n",
            "         2149,  5571,  1298,  1115,  1110,  1226,  1104,  1103,  3467,  1193,\n",
            "        24930, 25685, 25924, 24252, 20371,   113,  6820,  9159,   114,   119,\n",
            "         1109,  9232,  4168,   170,  3686,  1107,  1103,  7085, 13889,  1298,\n",
            "         1104, 21519,  2149,   117,  3646,  1300,  2699,  4675,  1105,  1126,\n",
            "         2078,  1121,  1103,  6688,  3469,   117,  1469,  1433,  3509,  1163,\n",
            "          119, 21519,  2149,  1110,  1141,  1104,  1103, 19585,  2737, 19972,\n",
            "        10059,  4001,  1485,  1103, 13099,  3070,   119,  4354,  1107,  1103,\n",
            "         1805,  1144,  4725,  1107,  2793,  1201,  1170,  1103,  9651,  1764,\n",
            "         5378,  5810,  1116,  1175,  1222,  1103,  2393, 20205,   118,  5128,\n",
            "         9651, 17677,   117,  1134,  1110,  1737,  1103,  1583,   188,  4583,\n",
            "         2699,  4433,   119,  1252, 19342,  2760,  1106,  2016,  3690,  1107,\n",
            "         6820,  9159,   117,  1134,  2606,  1228,   118, 22379,  1111, 16714,\n",
            "         1105,  1110,  1737,  1141,  1104,  1103,  1211, 23332,  2192,  1104,\n",
            "         1103,  4272,   118,  4223,  3790,  1104, 19980,  1550,  1234,   119,\n",
            "          102,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
            "            0,     0]) \n",
            "train_attention_masks  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
            "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
            "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
            "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
            "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
            "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
            "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
            "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
            "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0]) tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
            "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
            "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
            "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
            "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
            "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
            "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
            "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
            "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
            "        0, 0, 0, 0, 0, 0, 0, 0]) \n",
            "train_labels tensor(1.)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### TextClassificationDataset\n",
        "1. For tunning `bert-based-uncased`: each item of Dataset must be of type dictionary with at following  keys:\n",
        "  * input_ids\n",
        "  * attention_mask\n",
        "  * labels\n",
        "2. Thus,  `__getitem__`  should return dictionary of following structure:\n",
        "```\n",
        "{\n",
        "            'input_ids': self.input_ids[idx],\n",
        "            'attention_mask': self.attention_masks[idx],\n",
        "            'labels': self.one_hot_labels[idx]\n",
        "        }\n",
        "```\n",
        "3. one_hot_encode method: A static method that takes in targets (labels) and num_classes as arguments. It converts the given targets into one-hot encoded tensors. The method first converts the targets to long tensors and then initializes a zero tensor of shape (number of samples, num_classes). The scatter_ function is used to place 1.0 in the appropriate position for each sample's label, resulting in a one-hot encoded tensor."
      ],
      "metadata": {
        "id": "ZTyeVXy7t-T1"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class TextClassificationDataset(torch.utils.data.Dataset):\n",
        "    def __init__(self, input_ids, attention_masks, labels, num_classes=2):\n",
        "        self.input_ids = input_ids\n",
        "        self.attention_masks = attention_masks\n",
        "        self.labels = labels\n",
        "        self.num_classes = num_classes\n",
        "        self.one_hot_labels = self.one_hot_encode(labels, num_classes)\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.input_ids)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        return {\n",
        "            'input_ids': self.input_ids[idx],\n",
        "            'attention_mask': self.attention_masks[idx],\n",
        "            'labels': self.one_hot_labels[idx]\n",
        "        }\n",
        "\n",
        "\n",
        "    @staticmethod\n",
        "    def one_hot_encode(targets, num_classes):\n",
        "        targets = targets.long()\n",
        "        one_hot_targets = torch.zeros(targets.size(0), num_classes)\n",
        "        one_hot_targets.scatter_(1, targets.unsqueeze(1), 1.0)\n",
        "        return one_hot_targets\n",
        "        \n",
        "\n",
        "train_dataset = TextClassificationDataset(train_input_ids, train_attention_masks, train_labels)\n",
        "val_dataset = TextClassificationDataset(val_input_ids, val_attention_masks, val_labels)\n"
      ],
      "metadata": {
        "id": "7yfXP_F2zVcs"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### DataLoader\n",
        "*italicized text*"
      ],
      "metadata": {
        "id": "KQxJ0TU8vGT1"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)\n",
        "eval_dataloader = DataLoader(val_dataset, batch_size=8)"
      ],
      "metadata": {
        "id": "ZmUysho6O8kp"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(len(train_dataset))\n",
        "len((val_dataset))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "G6GOjnpbz0Na",
        "outputId": "594909ce-782f-434a-ffba-d75ca5fc8e84"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "3592\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "898"
            ]
          },
          "metadata": {},
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "1.Revisiting dimension requirements for Transformers in Pytorch from Chapter 3: The encoder expects data with dimensions (seq_len, batch_size). However, Hugging Face's bert-based-uncased model requires data with dimensions (batch_size, seq_len). As a result, the output from the train_dataloader has dimensions of (batch_size, seq_len)."
      ],
      "metadata": {
        "id": "7i-QtecpxXAs"
      }
    },
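    {
      "cell_type": "markdown",
      "source": [
        "To make the two layouts concrete, here is a small sketch that uses nested lists in place of tensors. `to_sequence_first` is a hypothetical helper and the token IDs are made up; with real tensors the same conversion is a transpose of the first two dimensions, for example `tensor.transpose(0, 1)`.\n",
        "\n",
        "```python\n",
        "def to_sequence_first(batch):\n",
        "    # batch is a list of equal-length rows: shape (batch_size, seq_len).\n",
        "    # zip(*batch) groups the i-th token of every row: shape (seq_len, batch_size).\n",
        "    return [list(column) for column in zip(*batch)]\n",
        "\n",
        "\n",
        "batch_first = [\n",
        "    [101, 7592, 102, 0],     # example 0, padded with one 0\n",
        "    [101, 2088, 2003, 102],  # example 1\n",
        "]\n",
        "seq_first = to_sequence_first(batch_first)\n",
        "print(len(batch_first), len(batch_first[0]))  # batch_size, seq_len\n",
        "print(len(seq_first), len(seq_first[0]))      # seq_len, batch_size\n",
        "```\n",
        "\n",
        "No such conversion is needed in this notebook, because the DataLoader already yields batch-first tensors."
      ],
      "metadata": {}
    },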
    {
      "cell_type": "code",
      "source": [
        "item=next(iter(train_dataloader))\n",
        "item_ids,item_mask,item_labels=item['input_ids'],item['attention_mask'],item['labels']\n",
        "print ('item_ids, ',item_ids.shape, '\\n',\n",
        "       'item_mask, ',item_mask.shape, '\\n',\n",
        "       'item_labels, ',item_labels.shape, '\\n',)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QhpJsZniecpV",
        "outputId": "205e6918-cd87-4637-d098-21ddb13814ff"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "item_ids,  torch.Size([8, 512]) \n",
            " item_mask,  torch.Size([8, 512]) \n",
            " item_labels,  torch.Size([8, 2]) \n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-uncased\", num_labels=2)\n",
        "optimizer = AdamW(model.parameters(), lr=5e-5)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "de0AgV2QaYy4",
        "outputId": "f3360699-0a0b-4319-9e17-9872b318a52b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']\n",
            "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
            "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "/Users/premtimsina/opt/anaconda3/envs/transformer_learn/lib/python3.11/site-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
            "  warnings.warn(\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 5. Prepare Accelaerator\n",
        "What is Accelerator?\n",
        " 1. It provides an easy-to-use API for training deep learning models on various hardware accelerators, such as GPUs, TPUs, and Apple's Metal Performance Shaders (MPS).\n",
        "  * In our example, during training, we donot specifically select 'mps' device. THe accelerator automatically detects it and use 'mps' for training\n",
        " 2. The Accelerator library is particularly useful for distributed training and mixed-precision training."
      ],
      "metadata": {
        "id": "lPDCD6Em3R5n"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Prepare the model and optimizer\n",
        "accelerator = Accelerator()\n",
        "model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(\n",
        "    model, optimizer, train_dataloader, eval_dataloader\n",
        ")\n"
      ],
      "metadata": {
        "id": "GK9hKmD6OOoK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 5. Fine Tune The Model\n",
        "1. `lr_scheduler` in the provided code is an instance of a learning rate scheduler, which is responsible for adjusting the learning rate during the training process. The learning rate scheduler helps improve the training process by dynamically adjusting the learning rate based on the number of training steps. In this code, the learning rate starts with the initial value set in the optimizer and decreases linearly to 0 as the training progresses.\n",
        "2. Some benefit of lr_scheduler over optimizer alone are\n",
        "  * Faster convergence \n",
        "  * Avoid Overshooting: When using a fixed learning rate, the optimizer might overshoot the optimal solution, especially in the later stages of training. By decreasing the learning rate over time, the model can make smaller updates and fine-tune its weights\n",
        "  \n",
        "3. `progress_bar` is just utility to show the progress of training\n",
        "4. These are standard approach for fine tunning:\n",
        "```\n",
        " }\n",
        "        outputs = model(**batch)\n",
        "        loss = outputs.loss\n",
        "        accelerator.backward(loss)\n",
        "        optimizer.step()\n",
        "        lr_scheduler.step()\n",
        "        optimizer.zero_grad()\n",
        "        progress_bar.update(1)\n",
        "```\n",
        "  * each batch should be dictionary of structure {input_ids:torch.Tensor(), attention_mask: torch.Tensor(), labels: torch.Tensor()\n",
        "  * the dimension of input_ids=(batch_size, seq_len); attention_mask= (batch_size, seq_len); and labels=(batch_size,)\n",
        "  * You can notice that during training, we are not explicitly converting `tensor` into device; accelerator is automatically identifying the `device` and converting `tensor` into the appropriate format\n",
        "1. After each epoch, we are also printing the evaluation metrics over the evaluation dataset"
      ],
      "metadata": {
        "id": "e_3lyEAH34rv"
      }
    },
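    {
      "cell_type": "markdown",
      "source": [
        "To make the linear schedule concrete, here is a minimal pure-Python sketch of how the learning rate falls from its initial value to 0. The helper name `linear_lr`, the initial rate `5e-5`, and the 10-step run are assumptions chosen for illustration; they are not values taken from this notebook, and the helper only approximates what the `linear` scheduler does.\n",
        "\n",
        "```python\n",
        "def linear_lr(step, total_steps, initial_lr, warmup_steps=0):\n",
        "    # Learning rate after `step` scheduler steps (hypothetical helper)\n",
        "    if step < warmup_steps:\n",
        "        # Linear warmup from 0 up to initial_lr\n",
        "        return initial_lr * step / max(1, warmup_steps)\n",
        "    # Linear decay from initial_lr down to 0 at total_steps\n",
        "    remaining = max(0, total_steps - step)\n",
        "    return initial_lr * remaining / max(1, total_steps - warmup_steps)\n",
        "\n",
        "initial_lr = 5e-5   # assumed value, for illustration only\n",
        "total_steps = 10    # assumed value, for illustration only\n",
        "schedule = [linear_lr(s, total_steps, initial_lr) for s in range(total_steps + 1)]\n",
        "print(schedule[0], schedule[5], schedule[-1])\n",
        "```\n",
        "\n",
        "With `warmup_steps=0`, as in our training code, the learning rate starts at its full value, is halved by the midpoint of training, and reaches 0 at the last step."
      ],
      "metadata": {}
    },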
    {
      "cell_type": "code",
      "source": [
        "\n",
        "num_epochs = 1\n",
        "num_training_steps = num_epochs * len(train_dataloader)\n",
        "lr_scheduler = get_scheduler(\n",
        "      \"linear\",\n",
        "      optimizer=optimizer,\n",
        "      num_warmup_steps=0,\n",
        "      num_training_steps=num_training_steps\n",
        "  )\n",
        "progress_bar = tqdm(range(num_training_steps))\n",
        "\n",
        "for epoch in range(num_epochs):\n",
        "    for batch in train_dataloader:\n",
        "        outputs = model(**batch)\n",
        "        loss = outputs.loss\n",
        "        accelerator.backward(loss)\n",
        "        optimizer.step()\n",
        "        lr_scheduler.step()\n",
        "        optimizer.zero_grad()\n",
        "        progress_bar.update(1)\n",
        "    model.eval()\n",
        "    device = 'mps'\n",
        "    preds = []\n",
        "    out_label_ids = []\n",
        "    epochs=1\n",
        "    epoch=1\n",
        "\n",
        "    for batch in eval_dataloader:\n",
        "        with torch.no_grad():\n",
        "            inputs = {k: v.to(device) for k, v in batch.items()}\n",
        "            outputs = model(**inputs)\n",
        "            logits = outputs.logits\n",
        "\n",
        "        preds.extend(torch.argmax(logits.detach().cpu(), dim=1).numpy())\n",
        "        out_label_ids.extend(torch.argmax(inputs[\"labels\"].detach().cpu(),dim=1).numpy())\n",
        "    accuracy = accuracy_score(out_label_ids, preds)\n",
        "    f1 = f1_score(out_label_ids, preds, average='weighted')\n",
        "    recall = recall_score(out_label_ids, preds, average='weighted')\n",
        "    precision = precision_score(out_label_ids, preds, average='weighted')\n",
        "\n",
        "    print(f\"Epoch {epoch + 1}/{num_epochs} Evaluation Results:\")\n",
        "    print(f\"Accuracy: {accuracy}\")\n",
        "    print(f\"F1 Score: {f1}\")\n",
        "    print(f\"Recall: {recall}\")\n",
        "    print(f\"Precision: {precision}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "UaDCyPtoQlZR",
        "outputId": "8dc754e0-1cfc-48f2-827f-c899313dbbda"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
            "To disable this warning, you can either:\n",
            "\t- Avoid using `tokenizers` before the fork if possible\n",
            "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 449/449 [04:19<00:00,  1.71it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 2/1 Evaluation Results:\n",
            "Accuracy: 0.9977728285077951\n",
            "F1 Score: 0.9977724507291988\n",
            "Recall: 0.9977728285077951\n",
            "Precision: 0.9977820316957794\n"
          ]
        }
      ]
    },
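    {
      "cell_type": "markdown",
      "source": [
        "The evaluation output above reports the same value for accuracy and recall. That is expected with `average='weighted'`: each class's recall is weighted by its share of the true labels, so the weighted sum reduces to overall accuracy. The pure-Python sketch below illustrates this on made-up labels; `accuracy` and `weighted_recall` are hypothetical helpers written for this illustration, not scikit-learn's implementation.\n",
        "\n",
        "```python\n",
        "from collections import Counter\n",
        "\n",
        "def accuracy(y_true, y_pred):\n",
        "    # Fraction of examples whose prediction matches the true label\n",
        "    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)\n",
        "\n",
        "def weighted_recall(y_true, y_pred):\n",
        "    # Per-class recall, weighted by each class's share of the true labels\n",
        "    support = Counter(y_true)\n",
        "    total = 0.0\n",
        "    for cls, n in support.items():\n",
        "        tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)\n",
        "        total += (n / len(y_true)) * (tp / n)\n",
        "    return total\n",
        "\n",
        "# Made-up labels, for illustration only\n",
        "y_true = [0, 0, 0, 1, 1, 1, 1, 1]\n",
        "y_pred = [0, 0, 1, 1, 1, 1, 0, 1]\n",
        "print(accuracy(y_true, y_pred), weighted_recall(y_true, y_pred))\n",
        "```\n",
        "\n",
        "Precision and F1 do not simplify in the same way, which is why their values differ slightly from accuracy in our results."
      ],
      "metadata": {}
    },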
    {
      "cell_type": "markdown",
      "source": [
        "# 6. Inference Pipeline\n",
        "1. `tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
        "`: You need use the same tokenizer that was use for fine-tunning\n",
        "2. `logits.detach().cpu()`\n",
        "  * `detach is done to prevent  unintentional back-propogation\n",
        "  * `.cpu` is done so that the output is compatible with scikit-learn libraries for further computation"
      ],
      "metadata": {
        "id": "fDte1szszhiK"
      }
    },
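    {
      "cell_type": "markdown",
      "source": [
        "The inference function below returns only the index of the largest logit. If a confidence score is also useful, one option is to pass the logits through a softmax. The sketch below does this in pure Python on made-up logits, assuming the two logits are scores for mutually exclusive classes; `softmax` here is a hypothetical helper, and inside the notebook the equivalent would be `torch.softmax(logits, dim=1)`.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def softmax(logits):\n",
        "    # Subtract the maximum before exponentiating for numerical stability\n",
        "    m = max(logits)\n",
        "    exps = [math.exp(x - m) for x in logits]\n",
        "    total = sum(exps)\n",
        "    return [e / total for e in exps]\n",
        "\n",
        "logits = [-2.1, 3.4]  # made-up values, for illustration only\n",
        "probs = softmax(logits)\n",
        "# The predicted label is the index with the highest probability (same as argmax of the logits)\n",
        "pred_label_idx = max(range(len(probs)), key=probs.__getitem__)\n",
        "print(pred_label_idx, probs)\n",
        "```\n",
        "\n",
        "Softmax preserves the ordering of the logits, so the predicted index is unchanged; only the scale of the scores differs."
      ],
      "metadata": {}
    },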
    {
      "cell_type": "code",
      "source": [
        "from transformers import BertTokenizer\n",
        "import torch\n",
        "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
        "\n",
        "def inference(text, model,  label, device='mps'):\n",
        "    # Load the tokenizer\n",
        "\n",
        "    # Tokenize the input text\n",
        "    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)\n",
        "    # Move input tensors to the specified device (default: 'cpu')\n",
        "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
        "\n",
        "    # Set the model to evaluation mode and perform inference\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        outputs = model(**inputs)\n",
        "        logits = outputs.logits\n",
        "\n",
        "    # Get the index of the predicted label\n",
        "    pred_label_idx = torch.argmax(logits.detach().cpu(), dim=1).item()\n",
        "\n",
        "    print(f\"Predicted label index: {pred_label_idx}, actual label {label}\")\n",
        "    return pred_label_idx\n"
      ],
      "metadata": {
        "id": "RKwJIhUTkja9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#https://abcnews.go.com/US/tornado-confirmed-delaware-powerful-storm-moves-east/story?id=98293454\n",
        "text='\\\n",
        "WASHINGTON (ABC) A confirmed tornado was located near Bridgeville in Sussex County, Delaware, shortly after 6 p.m. ET Saturday, moving east at 50 mph, according to the National Weather Service. Downed trees and wires were reported in the area.\\\n",
        "'\n",
        "inference(text, model, 1.0)\n",
        "text=\"this is definately junk text I am typing\"\n",
        "inference(text, model, 0.0)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1D8y9AKRxfSS",
        "outputId": "40adf168-5d35-471d-da8a-aee57a1fb852"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Predicted label index: 1, actual label 1.0\n",
            "Predicted label index: 0, actual label 0.0\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0"
            ]
          },
          "metadata": {},
          "execution_count": 24
        }
      ]
    }
  ]
}